#include "util.hpp"
#include <iostream>
#include <vector>


namespace nn {

using namespace std;
/*
 * size - 样本输入规模， 例如每组数据有3个元素({1 1 0}) ,
 * 则代表该感知器有2个输入 input_type 输入数据类型 ，
 * 该类型必须为可计算类型，默认为double predict_type 输出数据类型，
 * 该类型必须为可计算类型，默认为double
 */
template <int size, typename input_type = double,
          typename predict_type = double>
class Perceptron {

private:
  // 每组测试数据的最后一数为期望的输出，故权重列表的长度应-1
  int inputSize;
  int trained_count;
  vector<double> W;
  double bias;
  predict_type (*Func)(double);
  void arithmetic_test() {
    // predict_type predict_test;
    input_type input_test;
    // arithmetic_constraint(predict_type);
    arithmetic_assert(input_test);
    arithmetic_constraint(input_test); // 要求所有类型都必须是可运算类型
  }

public:
  Perceptron(predict_type (*activation_func)(double)) : inputSize(size - 1) {
    arithmetic_test();
    bias = 0;
    trained_count = 0;
    for (size_t i = 0; i < inputSize; i++)
      W.push_back(0);

    Func = activation_func;
  };

  void train(const vector<vector<input_type>> &sample, int times, double rate) {
    int sample_num = sample.size();

    for (int trainNum = 0; trainNum < times; trainNum++) {
      for (int i = 0; i < sample_num; i++) {
        double sum = bias;

        for (int j = 0; j < inputSize; j++)
          sum += sample[i][j] * W[j];

        sum = Func(sum);                  //迭代结果
        sum = sample[i][inputSize] - sum; //与真实值的差

        for (int j = 0; j < inputSize; j++)
          W[j] += rate * sum * sample[i][j];

        bias += rate * sum; //更新权值
      }
      trained_count++;
    }
  }
  predict_type predict(const vector<input_type> &input) {
    double output = bias;
    for (unsigned int i = 0; i < input.size(); i++)
      output += input[i] * W[i];
    return Func(output);
  }

  void info() {
    cout << "input: " << inputSize << endl << "Weight: [";
    for (size_t i = 0; i < inputSize; i++)
      cout << W[i] << " ";
    cout << "]" << endl << " bias: " << bias << endl;
    cout << "train time: " << trained_count << endl;
  }

  ~Perceptron(){};
}; // Preceptron class

}; // namespace net